"""
pipeline for ner and nre: bert-crf、PURE-relation
"""

import json
import os.path
import re
import sys

sys.path.append("/data/zzj/nlp/nre/PURE/")

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from seqeval.metrics.sequence_labeling import get_entities
from tqdm import tqdm

from models.bert_for_ner import BertCrfForNer
from processors.ner_seq import EMRImageReportProcessor
from relation.models import BertForRelationApprox


class NerDataset(Dataset):
    """Wraps raw strings and tokenizes each one lazily for NER inference."""

    def __init__(self, data: list = None, tokenizer: PreTrainedTokenizerFast = None):
        self.tokenizer = tokenizer  # fast tokenizer (required for offset mappings)
        self.data = data            # list of raw text strings

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        # number of real (non-padding) tokens in this encoding
        input_len = encoding["attention_mask"][0].sum()
        return (
            encoding["input_ids"][0],
            encoding["attention_mask"][0],
            encoding["token_type_ids"][0],
            encoding["offset_mapping"][0],
            input_len,
            text,
        )


def collate_fn(batch):
    """Collate NerDataset items: pad token-level tensors, stack lengths, keep raw texts."""
    input_ids, attention_mask, token_type_ids, offset_mapping, input_len, texts = zip(*batch)
    padded = [pad_sequence(field, batch_first=True)
              for field in (input_ids, attention_mask, token_type_ids, offset_mapping)]
    return (*padded, torch.stack(input_len), texts)


def get_special_token(w, unused_tokens=True):
    """Return (registering on first use) the marker token for entity-marker key `w`.

    When `unused_tokens` is true, the marker is taken from BERT's reserved
    "[unusedN]" vocabulary slots, numbered by registration order; otherwise it is
    a lowercase "<w>" token. Relies on the module-level `special_tokens` dict.
    """
    if w not in special_tokens:
        if unused_tokens:
            marker = "[unused%d]" % (len(special_tokens) + 1)
        else:
            marker = ('<' + w + '>').lower()
        special_tokens[w] = marker
    return special_tokens[w]


def get_attention_mask(input_mask_):
    """Build a square levelled attention mask from a 1-D marker-level mask.

    Positions whose level is <= 1 (regular text tokens / padding) are visible to
    every position as-is; marker positions (level >= 2) are visible only to
    positions carrying the same marker level, isolating entity-pair markers from
    each other.

    Args:
        input_mask_: list of ints, one attention level per token.
    Returns:
        list of lists of ints, the len(input_mask_) x len(input_mask_) mask.
    """
    # Idiom fix: the original used enumerate() with an unused index and built the
    # rows with manual append loops; a nested comprehension is equivalent.
    return [
        [to_mask if to_mask <= 1 else int(from_mask == to_mask and from_mask > 0)
         for to_mask in input_mask_]
        for from_mask in input_mask_
    ]


class TextSplit:
    """Split long text into sentence-level and/or length-limited chunks.

    All methods return lists of (sub_text, start, end) tuples where start/end
    are character offsets into the input text (end exclusive).
    """

    def __init__(self, max_length, context_window):
        self.max_length = max_length          # max characters per chunk
        self.context_window = context_window  # char overlap between consecutive chunks
        # sentence terminators: ASCII and full-width !/? plus the Chinese full stop
        self.split_regex = re.compile(r"[!。！?？]")

    def split_text_re(self, text):
        """Split on sentence terminators (terminator kept with its sentence)."""
        pieces = []
        start = 0
        for m in self.split_regex.finditer(text):
            pieces.append((text[start:m.end()], start, m.end()))
            start = m.end()
        if start < len(text):  # trailing text with no terminator
            pieces.append((text[start:], start, len(text)))
        return pieces

    def split_text_limited_length(self, text):
        """Split into chunks of at most max_length chars, overlapping by context_window."""
        pieces = []
        start = 0
        while start < len(text):
            end = start + self.max_length
            sub_text = text[start:end]
            pieces.append((sub_text, start, start + len(sub_text)))
            # Bug fix: the original tested `end > len(text)`, so when the text
            # length fell exactly on a chunk boundary it emitted one extra chunk
            # that was entirely contained in the previous one.
            if end >= len(text):
                break
            start = end - self.context_window
        return pieces

    def split_into_short_samples_test(self, text):
        """Split text into sentence chunks (then length chunks) before prediction.

        Returns dicts with the 1-based sentence id, the chunk text, and its
        character offset into the original text.
        """
        new_sample_list = []
        for text_id, (sub_text, sub_start, sub_end) in enumerate(self.split_text_re(text), start=1):
            for sub_text_, sub_start_, sub_end_ in self.split_text_limited_length(sub_text):
                new_sample_list.append({
                    "id": text_id,
                    "text": sub_text_,
                    "char_offset": sub_start + sub_start_,
                })
        return new_sample_list

    def split(self, text):
        """Sentence-split, then length-split any sentence longer than max_length."""
        pieces = []
        start = 0
        for m in self.split_regex.finditer(text):
            self._append_piece(pieces, text[start:m.end()], start, m.end())
            start = m.end()
        if start < len(text):  # trailing text with no terminator
            self._append_piece(pieces, text[start:], start, len(text))
        return pieces

    def _append_piece(self, pieces, curr_text, start, end):
        """Append (text, start, end); over-long pieces are further length-split."""
        if len(curr_text) > self.max_length:
            for sub_text, sub_start, sub_end in self.split_text_limited_length(curr_text):
                pieces.append((sub_text, sub_start + start, sub_end + start))
        else:
            pieces.append((curr_text, start, end))


class Event:
    """A predicted entity span: `value` text at INCLUSIVE char offsets [start, end]."""

    def __init__(self, value, start, end, tag, score=0.):
        self.value = value  # surface text of the span
        self.start = start  # inclusive start char offset
        self.end = end      # inclusive end char offset
        self.tag = tag      # entity type label
        self.score = score  # average token confidence

    @property
    def offset(self):
        """(start, end) pair, both inclusive."""
        return self.start, self.end

    def __eq__(self, other):
        # Equality is by span only -- tag/value/score are deliberately ignored.
        if hasattr(other, "start") and hasattr(other, "end"):
            return self.start == other.start and self.end == other.end

        return False

    # NOTE: defining __eq__ without __hash__ leaves instances unhashable; they
    # are only ever stored in lists here, so that is fine.

    def overlap(self, other):
        """True if the two inclusive spans share at least one character."""
        return max(self.start, other.start) <= min(self.end, other.end)

    def __repr__(self):
        return "%s(%r, %r, %r, %r, %r)" % \
               (self.__class__.__name__, self.value, self.start, self.end, self.tag, self.score)

    # __str__ previously duplicated __repr__'s body verbatim; alias it instead.
    __str__ = __repr__

    def __len__(self):
        # inclusive span length in characters
        return self.end - self.start + 1


def cal_cpg(pred_set, gold_set, cpg):
    """Accumulate micro counts into `cpg` in place.

    cpg is a 3-element list: [correct_num, pred_num, gold_num].
    """
    cpg[0] += sum(1 for mark_str in pred_set if mark_str in gold_set)
    cpg[1] += len(pred_set)
    cpg[2] += len(gold_set)


def get_prf_scores(correct_num, pred_num, gold_num):
    """Return (precision, recall, f1) from micro counts.

    A tiny epsilon in each denominator guards against division by zero when a
    count is 0, yielding 0.0 instead of raising.
    """
    eps = 1e-12
    precision, recall = (correct_num / (denom + eps) for denom in (pred_num, gold_num))
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1


# ---- configuration & model loading (paths hard-coded for the original experiment host) ----
pred_path = "/data/zzj/nlp/mtl/TPlinker-joint-extraction/data4bert/image_report_0225/valid_data.json"
device = torch.device("cuda:2")  # NOTE(review): hard-coded GPU index -- confirm it exists on the target host
# NER: BERT+CRF model fine-tuned on EMR image reports.
ner_path = "/data/zzj/nlp/ner/BERT-NER-Pytorch/outputs/emr-image-report_output/bert-crf"
ner_model = BertCrfForNer.from_pretrained(ner_path)
ner_model.to(device)
ner_model.eval()
tok = AutoTokenizer.from_pretrained(ner_path)
ner_label_list = EMRImageReportProcessor().get_labels()
ner_id2label = {i: label for i, label in enumerate(ner_label_list)}
# Relation extraction: PURE-style approximate relation classifier.
nre_path = "/data/zzj/nlp/nre/PURE/data/emr-image-report/relation_approx/output-chinese-bert-wwm-ext-e50"
with open(os.path.join(nre_path, "label_list.json"), 'r') as f:
    nre_label_list = json.load(f)
    # Allowed (head entity type, relation, tail entity type) triples; entity
    # pairs whose types never co-occur here are not scored by the relation model.
    nre_label_pairs = [('修饰词', '修饰', '未见疾病'), ('修饰词', '修饰', '疾病'), ('修饰词', '修饰', '病灶'), ('修饰词', '修饰', '症状'),
                       ('修饰词', '修饰', '部位'), ('指代词', '侵犯部位', '指代词'), ('指代词', '侵犯部位', '部位'), ('指代词', '指代', '指代词'),
                       ('指代词', '检测项目', '检测项目'), ('指代词', '病变部位', '指代词'), ('指代词', '病变部位', '部位'), ('指代词', '病灶倾向', '疾病'),
                       ('指代词', '结果', '数值'), ('指代词', '结果', '症状'), ('检查名称', '指代', '检测项目'), ('检查名称', '结果', '症状'),
                       ('检测项目', '指代', '指代词'), ('检测项目', '结果', '数值'), ('检测项目', '结果', '病灶'), ('检测项目', '结果', '症状'),
                       ('疾病', '指代', '指代词'), ('疾病', '病变部位', '部位'), ('病灶', '侵犯部位', '指代词'), ('病灶', '侵犯部位', '部位'),
                       ('病灶', '指代', '指代词'), ('病灶', '指代', '病灶'), ('病灶', '指代', '部位'), ('病灶', '检测项目', '检测项目'),
                       ('病灶', '病变部位', '指代词'), ('病灶', '病变部位', '病灶'), ('病灶', '病变部位', '部位'), ('病灶', '病灶倾向', '疾病'),
                       ('病灶', '症状表现部位', '指代词'), ('病灶', '症状表现部位', '部位'), ('病灶', '结果', '数值'), ('症状', '修饰', '未见疾病'),
                       ('症状', '修饰', '疾病'), ('症状', '修饰', '部位'), ('症状', '病变部位', '指代词'), ('症状', '病变部位', '部位'),
                       ('症状', '症状表现部位', '指代词'), ('症状', '症状表现部位', '疾病'), ('症状', '症状表现部位', '病灶'), ('症状', '症状表现部位', '部位'),
                       ('部位', '修饰', '部位'), ('部位', '指代', '指代词'), ('部位', '指代', '部位'), ('部位', '检测项目', '检测项目'),
                       ('部位', '症状表现部位', '部位'), ('部位', '结果', '数值'), ('部位', '结果', '病灶')]
    # Valid (head type, tail type) combinations derived from the triples above.
    ent_pair_type = []
    for head_type, _, tail_type in nre_label_pairs:
        ent_pair_type.append((head_type, tail_type))

# Marker-token registry consumed by get_special_token (key -> "[unusedN]" token).
with open(os.path.join(nre_path, "special_tokens.json"), "r") as f:
    special_tokens = json.load(f)
nre_label2id = {label: i for i, label in enumerate(nre_label_list)}
nre_id2label = {i: label for i, label in enumerate(nre_label_list)}
max_split_length = 256  # max characters per text chunk
# +2 for [CLS]/[SEP]; 25 * 4 presumably reserves room for up to 25 entity pairs
# x 4 marker tokens each -- TODO confirm against the relation model's training setup
max_seq_length = max_split_length + 2 + 25 * 4
context_window = 50  # character overlap between consecutive over-length chunks
num_labels = len(nre_label_list)
nre_model = BertForRelationApprox.from_pretrained(nre_path, num_rel_labels=num_labels)
nre_model.to(device)
nre_model.eval()

text_split = TextSplit(max_split_length, context_window)

# "correct", "pred", "gt"
ent_cpg = [0, 0, 0]
rel_cpg = [0, 0, 0]
with open(pred_path, encoding='utf8') as f:
    data = json.load(f)
for item in tqdm(data, desc="Predicting"):
    text = item["text"]
    gt_entity_list = item["entity_list"]
    gt_relation_list = item["relation_list"]
    gt_entity_set = {f"{ent['char_span'][0]},{ent['char_span'][1]}-{ent['text']}-{ent['type'].strip()}" for ent in
                     gt_entity_list}
    gt_relation_set = {
        f"{rel['subj_char_span'][0]},{rel['subj_char_span'][1]}-{rel['subject']}|{rel['predicate']}|" \
        f"{rel['obj_char_span'][0]},{rel['obj_char_span'][1]}-{rel['object']}"
        for rel in gt_relation_list
    }

    # todo: ner
    input_ids = []
    attention_mask = []
    token_type_ids = []
    sent_starts = []
    offset_mappings = []
    text_offset_list = text_split.split(text)
    for sub_text, sub_start, sub_end in text_offset_list:
        inputs = tok.encode_plus(sub_text,
                                 add_special_tokens=True,
                                 return_offsets_mapping=True,
                                 return_tensors="pt")
        input_ids.append(inputs["input_ids"][0])
        attention_mask.append(inputs["attention_mask"][0])
        token_type_ids.append(inputs["token_type_ids"][0])
        offset_mappings.append(inputs["offset_mapping"][0])
        sent_starts.append(sub_start)

    input_ids = pad_sequence(input_ids, batch_first=True).to(device)
    attention_mask = pad_sequence(attention_mask, batch_first=True).to(device)
    token_type_ids = pad_sequence(token_type_ids, batch_first=True).to(device)
    with torch.no_grad():
        outputs = ner_model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
    logits = outputs[0]
    scores = torch.softmax(logits, dim=-1).max(dim=-1)[0].cpu().numpy().tolist()
    tags = ner_model.crf.decode(logits, attention_mask)
    tags = tags.squeeze(0).cpu().numpy().tolist()
    ner_events = []
    for sent_tags, sent_scores, sent_start, offset_mapping in zip(tags, scores, sent_starts, offset_mappings):
        sent_tags = [ner_id2label[t] for t in sent_tags]
        for tag, tok_s, tok_e in get_entities(sent_tags):
            avg_score = sum(sent_scores[tok_s: tok_e + 1]) / (tok_e + 1 - tok_s)
            if avg_score < 0.8:
                continue
            char_s = int(offset_mapping[tok_s][0]) + sent_start
            char_e = int(offset_mapping[tok_e][1]) + sent_start
            curr_event = Event(text[char_s: char_e], char_s, char_e - 1, tag, avg_score)
            for i, event in enumerate(ner_events):
                if event.overlap(curr_event):
                    if len(curr_event) > len(event):
                        ner_events[i] = curr_event

                    break
            else:
                ner_events.append(curr_event)

    pred_entity_set = {f"{event.start},{event.end + 1}-{event.value}-{event.tag}" for event in ner_events}

    # # fixme: use gt data test
    # ner_events = []
    # for ent in gt_entity_list:
    #     start, end = ent['char_span']
    #     event = Event(ent['text'], start, end - 1, ent['type'].strip(), 1.)
    #     ner_events.append(event)

    # todo: nre
    pair_events = []
    all_input_ids = []
    all_position_ids = []
    all_input_mask = []
    all_segment_ids = []
    all_sub_obj_ids = []
    for sub_text, sub_start, sub_end in text_offset_list:
        curr_ner_events = [event for event in ner_events if event.start >= sub_start and event.end < sub_end]
        inputs = tok.encode_plus(sub_text, add_special_tokens=True, return_offsets_mapping=True)
        tokens = tok.convert_ids_to_tokens(inputs["input_ids"])
        token_start = token_end = {
            j: word_offset
            for word_offset, (char_start_offset, char_end_offset), in enumerate(inputs["offset_mapping"])
            for j in range(char_start_offset, char_end_offset)
        }

        num_tokens = len(tokens)
        position_ids = list(range(len(tokens)))
        marker_mask = 1
        input_mask = [1] * len(tokens)
        sub_obj_ids = []
        for head_event in curr_ner_events:
            for tail_event in curr_ner_events:
                if head_event.offset == tail_event.offset:
                    continue
                elif (head_event.tag, tail_event.tag) not in ent_pair_type:
                    continue

                SUBJECT_START_NER = get_special_token("SUBJ_START=%s" % head_event.tag)
                SUBJECT_END_NER = get_special_token("SUBJ_END=%s" % head_event.tag)
                OBJECT_START_NER = get_special_token("OBJ_START=%s" % tail_event.tag)
                OBJECT_END_NER = get_special_token("OBJ_END=%s" % tail_event.tag)

                if len(tokens) + 4 > max_seq_length:
                    input_ids = torch.tensor(tok.convert_tokens_to_ids(tokens), dtype=torch.long)
                    position_ids = torch.tensor(position_ids, dtype=torch.long)
                    segment_ids = torch.tensor([0] * len(input_ids), dtype=torch.long)
                    input_mask = torch.tensor(get_attention_mask(input_mask), dtype=torch.long)
                    sub_obj_ids = torch.tensor(sub_obj_ids, dtype=torch.long)
                    all_input_ids.append(input_ids)
                    all_position_ids.append(position_ids)
                    all_input_mask.append(input_mask)
                    all_segment_ids.append(segment_ids)
                    all_sub_obj_ids.append(sub_obj_ids)

                    tokens = tokens[:num_tokens]
                    position_ids = list(range(len(tokens)))
                    marker_mask = 1
                    input_mask = [1] * len(tokens)
                    sub_obj_ids = []

                pair_events.append([head_event, tail_event])
                tokens = tokens + [SUBJECT_START_NER, SUBJECT_END_NER, OBJECT_START_NER, OBJECT_END_NER]
                position_ids = position_ids + [token_start[head_event.start - sub_start],
                                               token_end[head_event.end - sub_start],
                                               token_start[tail_event.start - sub_start],
                                               token_end[tail_event.end - sub_start]]
                marker_mask += 1
                input_mask = input_mask + [marker_mask] * 4
                sub_obj_ids.append([len(tokens) - 4, len(tokens) - 2])

        input_ids = torch.tensor(tok.convert_tokens_to_ids(tokens), dtype=torch.long)
        position_ids = torch.tensor(position_ids, dtype=torch.long)
        segment_ids = torch.tensor([0] * len(input_ids), dtype=torch.long)
        input_mask = torch.tensor(get_attention_mask(input_mask), dtype=torch.long)
        sub_obj_ids = torch.tensor(sub_obj_ids, dtype=torch.long)
        all_input_ids.append(input_ids)
        all_position_ids.append(position_ids)
        all_input_mask.append(input_mask)
        all_segment_ids.append(segment_ids)
        all_sub_obj_ids.append(sub_obj_ids)

    input_ids = pad_sequence(all_input_ids, batch_first=True).to(device)
    position_ids = pad_sequence(all_position_ids, batch_first=True).to(device)
    batch_input_mask = all_input_mask
    if len(batch_input_mask) > 1:
        batch_max_len = max(x.size(0) for x in batch_input_mask)
        for j in range(len(batch_input_mask)):
            curr_mask_size = batch_input_mask[j].size(0)
            if curr_mask_size < batch_max_len:
                p2d = (0, batch_max_len - curr_mask_size, 0, batch_max_len - curr_mask_size)
                batch_input_mask[j] = F.pad(batch_input_mask[j], p2d, "constant", 0)

    input_mask = pad_sequence(batch_input_mask, batch_first=True).to(device)
    segment_ids = pad_sequence(all_segment_ids, batch_first=True).to(device)
    length = input_ids.size(1)
    sub_obj_ids = []
    sub_obj_ids_count = []
    for sub_obj_id in all_sub_obj_ids:
        sub_obj_ids_count.append(len(sub_obj_id))
        ids_padding = torch.tensor([[0, 0]] * (length // 4 - len(sub_obj_id)))
        sub_obj_id = torch.cat([sub_obj_id, ids_padding], dim=0)
        sub_obj_ids.append(sub_obj_id)
    sub_obj_ids = torch.stack(sub_obj_ids, dim=0).to(device)
    with torch.no_grad():
        logits = nre_model(input_ids, segment_ids, input_mask, sub_obj_ids=sub_obj_ids,
                           input_position=position_ids)
        logits = torch.softmax(logits, dim=-1)
    preds, scores = [], []
    for j, logit in enumerate(logits):
        scores_, preds_ = logit.max(-1)
        for x in range(sub_obj_ids_count[j]):
            preds.append(nre_id2label[preds_[x].item()])
            scores.append(scores_[x])

    pred_relation_set = set()
    for (head_event, tail_event), pred_type, score, in zip(pair_events, preds, scores):
        if pred_type != "none":
            pred_relation_set.add(
                f"{head_event.start},{head_event.end + 1}-{head_event.value}|{pred_type}|" \
                f"{tail_event.start},{tail_event.end + 1}-{tail_event.value}"
            )

    cal_cpg(pred_entity_set, gt_entity_set, ent_cpg)
    cal_cpg(pred_relation_set, gt_relation_set, rel_cpg)

ent_precision, ent_recall, ent_f1 = get_prf_scores(*ent_cpg)
rel_precision, rel_recall, rel_f1 = get_prf_scores(*rel_cpg)
print(f"ent: p={ent_precision}, r={ent_recall}, f1={ent_f1}")
print(f"rel: p={rel_precision}, r={rel_recall}, f1={rel_f1}")
